{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class ResNetBlock(nn.Module): # <1>\n",
    "\n",
    "    def __init__(self, dim):\n",
    "        super(ResNetBlock, self).__init__()\n",
    "        self.conv_block = self.build_conv_block(dim)\n",
    "\n",
    "    def build_conv_block(self, dim):\n",
    "        conv_block = []\n",
    "\n",
    "        conv_block += [nn.ReflectionPad2d(1)]\n",
    "\n",
    "        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),\n",
    "                       nn.InstanceNorm2d(dim),\n",
    "                       nn.ReLU(True)]\n",
    "\n",
    "        conv_block += [nn.ReflectionPad2d(1)]\n",
    "\n",
    "        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),\n",
    "                       nn.InstanceNorm2d(dim)]\n",
    "\n",
    "        return nn.Sequential(*conv_block)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = x + self.conv_block(x) # <2>\n",
    "        return out\n",
    "\n",
    "\n",
    "class ResNetGenerator(nn.Module):\n",
    "    \"\"\"Image-to-image generator: stem, strided downsampling, residual blocks, upsampling, Tanh head.\n",
    "\n",
    "    Args:\n",
    "        input_nc: number of channels in the input image.\n",
    "        output_nc: number of channels in the generated image.\n",
    "        ngf: number of filters produced by the first convolution.\n",
    "        n_blocks: number of ``ResNetBlock`` modules between down- and upsampling.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9): # <3>\n",
    "\n",
    "        assert n_blocks >= 0\n",
    "        super().__init__()\n",
    "\n",
    "        self.input_nc = input_nc\n",
    "        self.output_nc = output_nc\n",
    "        self.ngf = ngf\n",
    "\n",
    "        n_downsampling = 2\n",
    "        channels = ngf  # running channel count as layers are appended\n",
    "\n",
    "        # Stem: 7x7 convolution on a reflection-padded input.\n",
    "        layers = [\n",
    "            nn.ReflectionPad2d(3),\n",
    "            nn.Conv2d(input_nc, channels, kernel_size=7, padding=0, bias=True),\n",
    "            nn.InstanceNorm2d(channels),\n",
    "            nn.ReLU(True),\n",
    "        ]\n",
    "\n",
    "        # Downsampling: each stride-2 convolution doubles the channel count.\n",
    "        for _ in range(n_downsampling):\n",
    "            layers.extend([\n",
    "                nn.Conv2d(channels, channels * 2, kernel_size=3,\n",
    "                          stride=2, padding=1, bias=True),\n",
    "                nn.InstanceNorm2d(channels * 2),\n",
    "                nn.ReLU(True),\n",
    "            ])\n",
    "            channels *= 2\n",
    "\n",
    "        # Residual trunk at the widest channel count.\n",
    "        layers.extend(ResNetBlock(channels) for _ in range(n_blocks))\n",
    "\n",
    "        # Upsampling: each stride-2 transposed convolution halves the channel count.\n",
    "        for _ in range(n_downsampling):\n",
    "            layers.extend([\n",
    "                nn.ConvTranspose2d(channels, channels // 2,\n",
    "                                   kernel_size=3, stride=2,\n",
    "                                   padding=1, output_padding=1,\n",
    "                                   bias=True),\n",
    "                nn.InstanceNorm2d(channels // 2),\n",
    "                nn.ReLU(True),\n",
    "            ])\n",
    "            channels //= 2\n",
    "\n",
    "        # Head: 7x7 convolution to output_nc channels, followed by Tanh.\n",
    "        layers.extend([\n",
    "            nn.ReflectionPad2d(3),\n",
    "            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),\n",
    "            nn.Tanh(),\n",
    "        ])\n",
    "\n",
    "        self.model = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, input): # <3>\n",
    "        \"\"\"Run ``input`` through the full layer stack.\"\"\"\n",
    "        return self.model(input)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "source": [
    "# Build the generator with the constructor defaults; its weights are loaded in a later cell.\n",
    "netG = ResNetGenerator()"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "source": [
    "model_path = '../data/p1ch2/horse2zebra_0.4.0.pth'\n",
    "# map_location='cpu' makes the load independent of the device the checkpoint was\n",
    "# saved from, so it also works on machines without CUDA; netG itself is on the CPU.\n",
    "# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.\n",
    "model_data = torch.load(model_path, map_location='cpu')\n",
    "netG.load_state_dict(model_data)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "source": [
    "# Switch the generator to evaluation mode before running inference.\n",
    "netG.eval()"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "source": [
    "from PIL import Image\n",
    "from torchvision import transforms"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "source": [
    "# Resize (an int size matches the image's shorter side to 256), then convert to a tensor.\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(256),\n",
    "    transforms.ToTensor(),\n",
    "])"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "source": [
    "# Force three channels: the generator's first convolution expects input_nc=3, so a\n",
    "# grayscale, palette or RGBA file would otherwise fail at inference time.\n",
    "img = Image.open(\"../data/p1ch2/horse.jpg\").convert(\"RGB\")\n",
    "img"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "source": [
    "img_t = preprocess(img)\n",
    "# The generator is fed a batch, so add a leading dimension of size 1.\n",
    "batch_t = img_t.unsqueeze(0)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "source": [
    "# Inference only: disable autograd so no graph or intermediate activations are kept.\n",
    "with torch.no_grad():\n",
    "    batch_out = netG(batch_t)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "source": [
    "# The generator ends in Tanh, so its output lies in [-1, 1]; rescale to [0, 1] for ToPILImage.\n",
    "# squeeze(0) drops only the batch dimension, and detach() replaces the legacy .data attribute.\n",
    "out_t = (batch_out.detach().squeeze(0) + 1.0) / 2.0\n",
    "out_img = transforms.ToPILImage()(out_t)\n",
    "# out_img.save('../data/p1ch2/zebra.jpg')\n",
    "out_img"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [],
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
